{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# <center>保存模型</center>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 概述\n",
    "在模型训练过程中，可以添加检查点(CheckPoint)用于保存模型的参数，以便进行推理及再训练使用。如果想继续在不同硬件平台上做推理，可通过网络和CheckPoint格式文件生成对应的MindIR、AIR和ONNX格式文件。\n",
    "\n",
    "- MindIR：MindSpore的一种基于图表示的函数式IR，其最核心的目的是服务于自动微分变换，目前可用于MindSpore Lite端侧推理。\n",
    "\n",
    "- CheckPoint：MindSpore的存储了所有训练参数值的二进制文件。采用了Google的Protocol Buffers机制，与开发语言、平台无关，具有良好的可扩展性。CheckPoint的protocol格式定义`在mindspore/ccsrc/utils/checkpoint.proto`中。\n",
    "\n",
    "- AIR：全称Ascend Intermediate Representation，类似ONNX，是华为定义的针对机器学习所设计的开放式的文件格式，能更好地适配Ascend AI处理器。\n",
    "\n",
    "- ONNX：全称Open Neural Network Exchange，是一种针对机器学习所设计的开放式的文件格式，用于存储训练好的模型。\n",
    "\n",
    "以下通过图片分类应用示例来介绍保存CheckPoint格式文件和导出MindIR、AIR和ONNX格式文件的方法。\n",
    "\n",
    "> 本文档适用于CPU、GPU和Ascend AI处理器环境。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "说明：<br/>在保存和转换模型前，我们需要完整进行图片分类训练，包含数据准备、定义网络、定义损失函数及优化器和训练网络，详细信息可以参考：https://gitee.com/mindspore/docs/blob/master/tutorials/notebook/mindspore_quick_start.ipynb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "整体流程如下：\n",
    "\n",
    "1. 数据准备\n",
    "\n",
    "2. 构造神经网络\n",
    "\n",
    "3. 搭建训练网络、定义损失函数及优化器\n",
    "\n",
    "4. 保存CheckPoint格式文件\n",
    "\n",
    "5. 导出不同格式文件\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据准备"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 下载MNIST数据集"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "运行以下命令来获取数据集："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2020-11-30 16:14:55--  https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/datasets/MNIST_Data.zip\n",
      "Resolving proxy-notebook.modelarts-dev-proxy.com (proxy-notebook.modelarts-dev-proxy.com)... 192.168.0.172\n",
      "Connecting to proxy-notebook.modelarts-dev-proxy.com (proxy-notebook.modelarts-dev-proxy.com)|192.168.0.172|:8083... connected.\n",
      "Proxy request sent, awaiting response... 200 OK\n",
      "Length: 10754903 (10M) [application/zip]\n",
      "Saving to: ‘MNIST_Data.zip’\n",
      "\n",
      "MNIST_Data.zip      100%[===================>]  10.26M  64.9MB/s    in 0.2s    \n",
      "\n",
      "2020-11-30 16:14:56 (64.9 MB/s) - ‘MNIST_Data.zip’ saved [10754903/10754903]\n",
      "\n",
      "Archive:  MNIST_Data.zip\n",
      "   creating: MNIST_Data/test/\n",
      "  inflating: MNIST_Data/test/t10k-images-idx3-ubyte  \n",
      "  inflating: MNIST_Data/test/t10k-labels-idx1-ubyte  \n",
      "   creating: MNIST_Data/train/\n",
      "  inflating: MNIST_Data/train/train-images-idx3-ubyte  \n",
      "  inflating: MNIST_Data/train/train-labels-idx1-ubyte  \n",
      "./datasets/MNIST_Data\n",
      "├── test\n",
      "│   ├── t10k-images-idx3-ubyte\n",
      "│   └── t10k-labels-idx1-ubyte\n",
      "└── train\n",
      "    ├── train-images-idx3-ubyte\n",
      "    └── train-labels-idx1-ubyte\n",
      "\n",
      "2 directories, 4 files\n"
     ]
    }
   ],
   "source": [
    "!wget -N https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/datasets/MNIST_Data.zip\n",
    "!unzip -o MNIST_Data.zip -d ./datasets\n",
    "!tree ./datasets/MNIST_Data/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 数据处理："
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "数据集对于训练非常重要，好的数据集可以有效提高训练精度和效率，在加载数据集前，我们通常会对数据集进行一些处理。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 定义数据集及数据操作\n",
    "我们定义一个函数`create_dataset`来创建数据集。在这个函数中，我们定义好需要进行的数据增强和处理操作：\n",
    "1. 定义数据集。\n",
    "2. 定义进行数据增强和处理所需要的一些参数。\n",
    "3. 根据参数，生成对应的数据增强操作。\n",
    "4. 使用`map`映射函数，将数据操作应用到数据集。\n",
    "5. 对生成的数据集进行处理。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-04T06:46:31.263831Z",
     "start_time": "2020-09-04T06:46:31.242077Z"
    }
   },
   "outputs": [],
   "source": [
    "import mindspore.dataset.vision.c_transforms as CV\n",
    "import mindspore.dataset.transforms.c_transforms as C\n",
    "from mindspore.dataset.vision import Inter\n",
    "from mindspore import dtype as mstype\n",
    "import matplotlib\n",
    "import mindspore.dataset as ds\n",
    "\n",
    "\n",
    "def create_dataset(data_path, batch_size=32, repeat_size=1,\n",
    "                   num_parallel_workers=1):\n",
    "    \"\"\" \n",
    "    create dataset for train or test\n",
    "    \n",
    "    Args:\n",
    "        data_path (str): Data path\n",
    "        batch_size (int): The number of data records in each group\n",
    "        repeat_size (int): The number of replicated data records\n",
    "        num_parallel_workers (int): The number of parallel workers\n",
    "    \"\"\"\n",
    "    # define dataset\n",
    "    mnist_ds = ds.MnistDataset(data_path)\n",
    "\n",
    "    # define some parameters needed for data enhancement and rough justification\n",
    "    resize_height, resize_width = 32, 32\n",
    "    rescale = 1.0 / 255.0\n",
    "    shift = 0.0\n",
    "    rescale_nml = 1 / 0.3081\n",
    "    shift_nml = -1 * 0.1307 / 0.3081\n",
    "\n",
    "    # according to the parameters, generate the corresponding data enhancement method\n",
    "    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)\n",
    "    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)\n",
    "    rescale_op = CV.Rescale(rescale, shift)\n",
    "    hwc2chw_op = CV.HWC2CHW()\n",
    "    type_cast_op = C.TypeCast(mstype.int32)\n",
    "\n",
    "    # using map to apply operations to a dataset\n",
    "    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns=\"label\", num_parallel_workers=num_parallel_workers)\n",
    "    mnist_ds = mnist_ds.map(operations=resize_op, input_columns=\"image\", num_parallel_workers=num_parallel_workers)\n",
    "    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns=\"image\", num_parallel_workers=num_parallel_workers)\n",
    "    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns=\"image\", num_parallel_workers=num_parallel_workers)\n",
    "    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns=\"image\", num_parallel_workers=num_parallel_workers)\n",
    "    \n",
    "    # process the generated dataset\n",
    "    buffer_size = 10000\n",
    "    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)\n",
    "    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)\n",
    "    mnist_ds = mnist_ds.repeat(repeat_size)\n",
    "\n",
    "    return mnist_ds\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 构造神经网络"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在对手写字体识别上，通常采用卷积神经网络架构（CNN）进行学习预测，最经典的属1998年由Yann LeCun创建的LeNet5架构，在构建LeNet5前，我们需要对全连接层以及卷积层进行初始化。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-04T06:46:32.448830Z",
     "start_time": "2020-09-04T06:46:31.265357Z"
    }
   },
   "outputs": [],
   "source": [
    "import mindspore.nn as nn\n",
    "from mindspore.common.initializer import Normal\n",
    "\n",
    "class LeNet5(nn.Cell):\n",
    "    \"\"\"Lenet network structure.\"\"\"\n",
    "    # define the operator required\n",
    "    def __init__(self, num_class=10, num_channel=1):\n",
    "        super(LeNet5, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')\n",
    "        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))\n",
    "        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))\n",
    "        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))\n",
    "        self.relu = nn.ReLU()\n",
    "        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        self.flatten = nn.Flatten()\n",
    "\n",
    "    # use the preceding operators to construct networks\n",
    "    def construct(self, x):\n",
    "        x = self.max_pool2d(self.relu(self.conv1(x)))\n",
    "        x = self.max_pool2d(self.relu(self.conv2(x)))\n",
    "        x = self.flatten(x)\n",
    "        x = self.relu(self.fc1(x))\n",
    "        x = self.relu(self.fc2(x))\n",
    "        x = self.fc3(x) \n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 搭建训练网络"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "构建完成神经网络后，就可以着手进行训练网络的构建，包括定义损失函数及优化器。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-04T06:46:57.649137Z",
     "start_time": "2020-09-04T06:46:33.811666Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "========== The Training Model is Defined. ==========\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "from mindspore.nn import SoftmaxCrossEntropyWithLogits\n",
    "from mindspore.nn import Accuracy\n",
    "from mindspore import context, Model\n",
    "\n",
    "context.set_context(mode=context.GRAPH_MODE, device_target=\"CPU\")\n",
    "\n",
    "lr = 0.01\n",
    "momentum = 0.9 \n",
    "\n",
    "# create the network\n",
    "network = LeNet5()\n",
    "\n",
    "# define the optimizer\n",
    "net_opt = nn.Momentum(network.trainable_params(), lr, momentum)\n",
    "\n",
    "# define the loss function\n",
    "net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')\n",
    "\n",
    "# define the model\n",
    "model = Model(network, net_loss, net_opt, metrics={\"Accuracy\": Accuracy()} )\n",
    "\n",
    "epoch_size = 1\n",
    "mnist_path = \"./datasets/MNIST_Data\"\n",
    "\n",
    "eval_dataset = create_dataset(\"./datasets/MNIST_Data/test\")\n",
    "\n",
    "repeat_size = 1\n",
    "print(\"========== The Training Model is Defined. ==========\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 保存CheckPoint格式文件"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在模型训练的过程中，使用Callback机制传入回调函数`ModelCheckpoint`对象，可以保存模型参数，生成CheckPoint文件。\n",
    "\n",
    "通过`CheckpointConfig`对象可以设置CheckPoint的保存策略。保存的参数分为网络参数和优化器参数。\n",
    "\n",
    "`ModelCheckpoint`提供默认配置策略，方便用户快速上手，用户可以根据具体需求对`CheckPoint`策略进行配置。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 单次运行训练脚本保存模型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在这里配置`CheckPoint`时，设置的是每隔375个steps就保存一次，最多保留10个CheckPoint文件，生成前缀名为“lenet”，具体用法如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "========== The Training is Starting. ==========\n",
      "========== The Training is Completed and the Checkpoint Files are Saved. ==========\n"
     ]
    }
   ],
   "source": [
    "from mindspore.train.callback import ModelCheckpoint, CheckpointConfig\n",
    "\n",
    "model_path = './models/ckpt/mindspore_save_model/'\n",
    "# clean up old run files before in Linux\n",
    "os.system('rm -f {}*.ckpt {}*.meta {}*.pb'.format(model_path, model_path, model_path))\n",
    "\n",
    "# define config_ck for specifying the steps to save the checkpoint and the maximum file numbers\n",
    "config_ck = CheckpointConfig(save_checkpoint_steps=375, keep_checkpoint_max=10)\n",
    "# define ckpoint_cb for specifying the prefix of the file and the saving directory\n",
    "ckpoint_cb = ModelCheckpoint(prefix='lenet', directory=model_path, config=config_ck)\n",
    "#load the training dataset\n",
    "ds_train = create_dataset(os.path.join(mnist_path, \"train\"), 32, repeat_size)\n",
    "print(\"========== The Training is Starting. ==========\")\n",
    "model.train(epoch_size, ds_train, callbacks=ckpoint_cb, dataset_sink_mode=False)\n",
    "print(\"========== The Training is Completed and the Checkpoint Files are Saved. ==========\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "上述代码中，首先需要初始化一个CheckpointConfig类对象，用来设置保存策略。\n",
    "\n",
    "- `save_checkpoint_steps`表示每隔多少个step保存一次。\n",
    "- `keep_checkpoint_max`表示最多保留CheckPoint文件的数量。\n",
    "- `prefix`表示生成CheckPoint文件的前缀名。\n",
    "- `directory`表示存放文件的目录。\n",
    "- `epoch_size`表示每个epoch需要遍历完成图片的batch数。\n",
    "- `ds_train`表示数据集。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "创建一个ModelCheckpoint对象把它传递给`model.train`方法，就可以在训练过程中使用CheckPoint功能了。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "生成的CheckPoint文件如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./models/ckpt/mindspore_save_model\n",
      "├── lenet-1_1125.ckpt\n",
      "├── lenet-1_1500.ckpt\n",
      "├── lenet-1_1875.ckpt\n",
      "├── lenet-1_375.ckpt\n",
      "├── lenet-1_750.ckpt\n",
      "└── lenet-graph.meta\n",
      "\n",
      "0 directories, 6 files\n"
     ]
    }
   ],
   "source": [
    "! tree ./models/ckpt/mindspore_save_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "其中：\n",
    "- `lenet-graph.meta`为编译后的计算图。\n",
    "- CheckPoint文件后缀名为`.ckpt`，文件的命名方式表示保存参数所在的epoch和step数。\n",
    "- `lenet-1_750.ckpt`表示保存的是第1个epoch的第750个step的模型参数。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 多次运行训练脚本保存模型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "如果用户使用相同的前缀名，运行多次训练脚本，可能会生成同名CheckPoint文件。MindSpore为方便用户区分每次生成的文件，会在用户定义的前缀后添加_和数字加以区分。如下所示："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "========== The Training is Starting. ==========\n",
      "========== The Training is Completed and the Checkpoint Files are Saved. ==========\n"
     ]
    }
   ],
   "source": [
    "from mindspore.train.callback import ModelCheckpoint, CheckpointConfig\n",
    "import os\n",
    "\n",
    "# clean up old run files before in Linux\n",
    "os.system('rm -f {}lenet_2*.ckpt'.format(model_path))\n",
    "\n",
    "config_ck = CheckpointConfig(save_checkpoint_steps=375, keep_checkpoint_max=10)\n",
    "# Specify that here the script is executed for the second time\n",
    "ckpoint_cb = ModelCheckpoint(prefix='lenet_2', directory='./models/ckpt/mindspore_save_model', config=config_ck)\n",
    "ds_train = create_dataset(os.path.join(mnist_path, \"train\"), 32, repeat_size)\n",
    "print(\"========== The Training is Starting. ==========\")\n",
    "model.train(epoch_size, ds_train, callbacks=ckpoint_cb,dataset_sink_mode=False)\n",
    "print(\"========== The Training is Completed and the Checkpoint Files are Saved. ==========\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "生成的CheckPoint文件（以`lenet-2`为前缀的`.ckpt`文件）如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./models/ckpt/mindspore_save_model\n",
      "├── lenet-1_1125.ckpt\n",
      "├── lenet-1_1500.ckpt\n",
      "├── lenet-1_1875.ckpt\n",
      "├── lenet-1_375.ckpt\n",
      "├── lenet-1_750.ckpt\n",
      "├── lenet_2-1_1125.ckpt\n",
      "├── lenet_2-1_1500.ckpt\n",
      "├── lenet_2-1_1875.ckpt\n",
      "├── lenet_2-1_375.ckpt\n",
      "├── lenet_2-1_750.ckpt\n",
      "├── lenet_2-graph.meta\n",
      "└── lenet-graph.meta\n",
      "\n",
      "0 directories, 12 files\n"
     ]
    }
   ],
   "source": [
    "! tree ./models/ckpt/mindspore_save_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "lenet_2-1_750.ckpt 表示本次运行脚本生成的第1个epoch的第750个step的CheckPoint文件。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 配置时间策略保存模型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "MindSpore提供了两种保存CheckPoint策略：迭代策略和时间策略，上述代码为迭代策略。我们可以通过创建`CheckpointConfig`对象设置相应策略，CheckpointConfig中共有四个参数可以设置：\n",
    "\n",
    "- `save_checkpoint_steps`：表示每隔多少个step保存一个CheckPoint文件，默认值为1。\n",
    "\n",
    "- `save_checkpoint_seconds`：表示每隔多少秒保存一个CheckPoint文件，默认值为0。\n",
    "\n",
    "- `keep_checkpoint_max`：表示最多保存多少个CheckPoint文件，默认值为5。\n",
    "\n",
    "- `keep_checkpoint_per_n_minutes`：表示每隔多少分钟保留一个CheckPoint文件，默认值为0。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "因为时间策略需要训练时间稍长一些，所以这里把`epoch_size`改为10。\n",
    "\n",
    "以下代码为时间策略："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "========== The Training is Starting. ==========\n",
      "========== The Training is Completed and the Checkpoint Files are Saved. ==========\n"
     ]
    }
   ],
   "source": [
    "from mindspore.train.callback import ModelCheckpoint, CheckpointConfig\n",
    "import os\n",
    "\n",
    "os.system('rm -f {}lenet_3*.ckpt'.format(model_path))\n",
    "# define config_ck for specifying the seconds to save the checkpoint and the maximum file numbers\n",
    "config_ck = CheckpointConfig(save_checkpoint_steps=None, save_checkpoint_seconds=10, keep_checkpoint_max=None, keep_checkpoint_per_n_minutes=1)\n",
    "# define ckpoint_cb for specifying the prefix of the file and the saving directory\n",
    "ckpoint_cb = ModelCheckpoint(prefix='lenet_3', directory='./models/ckpt/mindspore_save_model', config=config_ck)\n",
    "#load the training dataset\n",
    "epoch_size = 2\n",
    "ds_train = create_dataset(os.path.join(mnist_path, \"train\"), 32, repeat_size)\n",
    "print(\"========== The Training is Starting. ==========\")\n",
    "model.train(epoch_size, ds_train, callbacks=ckpoint_cb,dataset_sink_mode=False)\n",
    "print(\"========== The Training is Completed and the Checkpoint Files are Saved. ==========\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这里`save_checkpoint_seconds`和`keep_checkpoint_per_n_minutes`这两个分别设置为10和1。<br>\n",
    "意思是每10秒会保存一个CheckPoint文件，每隔1分钟会保留一个CheckPoint文件。假设训练持续了1分钟，那总共会生成7个CheckPoint文件，但是当训练结束后，实际上会看到4个CheckPoint文件（以`lenet-3`为前缀的`.ckpt`文件），即保存下来的3个文件和默认保存最后一个step的CheckPoint文件。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "生成的CheckPoint文件如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./models/ckpt/mindspore_save_model\n",
      "├── lenet-1_1125.ckpt\n",
      "├── lenet-1_1500.ckpt\n",
      "├── lenet-1_1875.ckpt\n",
      "├── lenet-1_375.ckpt\n",
      "├── lenet-1_750.ckpt\n",
      "├── lenet_2-1_1125.ckpt\n",
      "├── lenet_2-1_1500.ckpt\n",
      "├── lenet_2-1_1875.ckpt\n",
      "├── lenet_2-1_375.ckpt\n",
      "├── lenet_2-1_750.ckpt\n",
      "├── lenet_2-graph.meta\n",
      "├── lenet_3-1_1023.ckpt\n",
      "├── lenet_3-2_1254.ckpt\n",
      "├── lenet_3-2_1875.ckpt\n",
      "├── lenet_3-2_194.ckpt\n",
      "├── lenet_3-graph.meta\n",
      "└── lenet-graph.meta\n",
      "\n",
      "0 directories, 17 files\n"
     ]
    }
   ],
   "source": [
    "! tree ./models/ckpt/mindspore_save_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "另请注意，如果想要删除.ckpt文件时，请同步删除.meta 文件。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 导出MindIR格式文件"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "当有了CheckPoint文件后，如果想继续在MindSpore Lite端侧做推理，需要通过网络和CheckPoint生成对应的MindIR格式模型文件。当前支持基于静态图，且不包含控制流语义的推理网络导出。建议使用`.mindir`作为MINDIR格式文件的后缀名。导出该格式文件的代码如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore import export, load_checkpoint, load_param_into_net\n",
    "from mindspore import Tensor\n",
    "import numpy as np\n",
    "\n",
    "lenet = LeNet5()\n",
    "# return a parameter dict for model\n",
    "param_dict = load_checkpoint(\"./models/ckpt/mindspore_save_model/lenet-1_1875.ckpt\")\n",
    "# load the parameter into net\n",
    "load_param_into_net(lenet, param_dict)\n",
    "input = np.random.uniform(0.0, 1.0, size=[32, 1, 32, 32]).astype(np.float32)\n",
    "# export the file with the specified name and format\n",
    "export(lenet, Tensor(input), file_name='lenet-1_1875', file_format='MINDIR',)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "完成以后，在当前目录下会生成一个MindIR格式文件，文件名为：`lenet-1_1875.mindir`。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 导出ONNX格式文件"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "当有了CheckPoint文件后，如果想继续在Ascend AI处理器、GPU、CPU等多种硬件上做推理，需要通过网络和CheckPoint生成对应的ONNX格式模型文件，建议使用`.onnx`作为ONNX格式文件的后缀名。导出该格式文件的代码样例如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore.train.serialization import export, load_checkpoint, load_param_into_net\n",
    "from mindspore import Tensor\n",
    "import numpy as np\n",
    "lenet = LeNet5()\n",
    "# return a parameter dict for model\n",
    "param_dict = load_checkpoint(\"./models/ckpt/mindspore_save_model/lenet-1_1875.ckpt\")\n",
    "# load the parameter into net\n",
    "load_param_into_net(lenet, param_dict)\n",
    "input = np.random.uniform(0.0, 1.0, size=[32, 1, 32, 32]).astype(np.float32)\n",
    "# export the file with the specified name and format\n",
    "export(lenet, Tensor(input), file_name='lenet-1_1875', file_format='ONNX')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "完成以后，在当前目录下会生成一个ONNX格式文件，文件名为：`lenet-1_1875.onnx`。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 总结"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "以上就是保存模型并导出文件的全部体验过程，我们通过本次体验全面了解了训练模型的保存以及如何导出成为不同格式的文件，以便用于不同平台上的推理。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore-1.0.1",
   "language": "python",
   "name": "mindspore-1.0.1"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
